import torch
import torch.nn as nn
import numpy as np
import sys, os

# Put the parent of this script's directory on sys.path — presumably so the
# project-local `deepeye` import that follows resolves when the script is run
# directly (confirm against the repo layout).
# os.path.join is used instead of "dirname + os.path.sep + '..'": for a bare
# relative __file__ such as "script.py", dirname() is "" and the concatenation
# yields "/..", which abspath() resolves to the filesystem root rather than the
# parent directory.
hello_pytorch_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(hello_pytorch_DIR)

from deepeye.tools.common_tools import set_seed

# NOTE(review): presumably seeds the RNGs (torch / numpy / random) so runs are
# reproducible — confirm in deepeye.tools.common_tools.set_seed. The demo
# inputs below are all built from torch.ones, so the printed values may not
# depend on the seed at all.
set_seed(1)

# ============================== nn.BatchNorm1d ==================================
# Demo: run nn.BatchNorm1d on a (batch_size, num_features, 1) input twice and
# compare the module's running statistics with a hand-computed momentum update
# for the second feature (index 1). Flip the toggle below to 1 to run it.
# flag = 1
flag = 0
if flag:
    batch_size = 3
    num_features = 5
    momentum = 0.3

    # Trailing comma makes this a real 1-tuple; plain `(1)` is just the int 1.
    features_shape = (1,)

    # In the documentation's figure, one yellow box corresponds to one feature map.
    feature_map = torch.ones(features_shape)                                                 # 1D: (1,)
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 2D: (5, 1)
    feature_maps_bs = torch.stack([feature_maps for _ in range(batch_size)], dim=0)           # 3D: (3, 5, 1)
    print('input data:\n{} shape is {}'.format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm1d(num_features=num_features, momentum=momentum)

    # Hand-tracked running stats for feature #2, starting from 0 / 1 to mirror
    # PyTorch's initial running_mean / running_var (see the BatchNorm docs).
    running_mean, running_var = 0, 1
    second_feature = feature_maps_bs[:, 1]  # every value of feature #2 across the batch

    for i in range(2):
        # Forward pass in training mode; this updates bn.running_mean / bn.running_var.
        outputs = bn(feature_maps_bs)

        print("\niteration:{}, running_mean:{}".format(i, bn.running_mean))
        print("iteration:{}, running_var:{}".format(i, bn.running_var))

        # Batch statistics of feature #2, derived from the data rather than
        # hard-coded. Every element equals 2 here, so mean = 2 and var = 0.
        # Tensor.var() defaults to the unbiased estimator, which is what
        # BatchNorm uses when updating running_var.
        mean_t = second_feature.mean().item()
        var_t = second_feature.var().item()

        # Same exponential-moving-average rule BatchNorm applies with `momentum`.
        running_mean = (1 - momentum) * running_mean + momentum * mean_t
        running_var = (1 - momentum) * running_var + momentum * var_t

        print("\niteration:{}, 第二个特征的running_mean:{}".format(i, running_mean))
        print("iteration:{}, 第二个特征的running_var:{}".format(i, running_var))

# ============================== nn.BatchNorm2d ==================================
# Demo: run nn.BatchNorm2d on a (batch_size, num_features, 2, 2) input and print
# the shapes of its running statistics and affine parameters, which hold one
# entry per feature. Flip the toggle below to 1 to run it.
# flag = 1
flag = 0
if flag:
    batch_size = 3
    num_features = 6
    momentum = 0.3

    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)                                                 # 2D: (2, 2)
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 3D: (6, 2, 2)
    feature_maps_bs = torch.stack([feature_maps for _ in range(batch_size)], dim=0)           # 4D: (3, 6, 2, 2)
    print('input data:\n{} shape is {}'.format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm2d(num_features=num_features, momentum=momentum)

    for i in range(2):
        # Forward pass in training mode; this updates the running statistics.
        outputs = bn(feature_maps_bs)

        # The buffers map one-to-one onto the features: mean and variance are
        # computed separately per feature, so both print torch.Size([6]) on
        # iterations 0 and 1.
        print("\niteration:{}, running_mean.shape:{}".format(i, bn.running_mean.shape))
        print("iteration:{}, running_var.shape:{}".format(i, bn.running_var.shape))

        # "weight.shape:" now carries the colon, matching the "name:value" layout of the other lines.
        print("iteration:{}, weight.shape:{}".format(i, bn.weight.shape))
        print("iteration:{}, bias.shape:{}".format(i, bn.bias.shape))


# ============================== nn.BatchNorm3d ==================================
# Demo: run nn.BatchNorm3d on a (batch_size, num_features, 2, 2, 3) input and
# print the shapes of its running statistics and affine parameters, which hold
# one entry per feature. This section is currently enabled; comment out
# `flag = 1` and uncomment the line below it to disable it.
flag = 1
# flag = 0
if flag:
    batch_size = 3
    num_features = 4
    momentum = 0.3

    features_shape = (2, 2, 3)

    feature = torch.ones(features_shape)                                                 # 3D: (2, 2, 3)
    feature_map = torch.stack([feature * (i + 1) for i in range(num_features)], dim=0)  # 4D: (4, 2, 2, 3)
    feature_maps = torch.stack([feature_map for _ in range(batch_size)], dim=0)          # 5D: (3, 4, 2, 2, 3)
    print('input data:\n{} shape is {}'.format(feature_maps, feature_maps.shape))

    bn = nn.BatchNorm3d(num_features=num_features, momentum=momentum)

    for i in range(2):
        # Forward pass in training mode; this updates the running statistics.
        outputs = bn(feature_maps)

        # One entry per feature, exactly as in the 1d/2d cases.
        print("\niteration:{}, running_mean.shape:{}".format(i, bn.running_mean.shape))
        print("iteration:{}, running_var.shape:{}".format(i, bn.running_var.shape))

        # "weight.shape:" now carries the colon, matching the "name:value" layout of the other lines.
        print("iteration:{}, weight.shape:{}".format(i, bn.weight.shape))
        print("iteration:{}, bias.shape:{}".format(i, bn.bias.shape))